#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 19 15:54:38 2022

@author: qiguangyao
"""


#%%Lib
import copy
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns 
from scipy import asarray as ar,exp
from scipy.optimize import curve_fit
import math
import pingouin as pg
from sklearn import linear_model
from pylab import cos
import pandas as pd
import random
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm
from statsmodels.sandbox.stats.multicomp import multipletests # for multiple comparisons correction
from statsmodels.stats.multicomp import pairwise_tukeyhsd
print("__file Output:",__file__)
#%%functions
import scipy.stats
def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean together with a two-sided Student-t interval.

    Parameters
    ----------
    data : array_like
        Sample values; ``len(data)`` is taken as the sample size.
    confidence : float, optional
        Coverage of the interval (default 0.95).

    Returns
    -------
    tuple
        ``(mean, lower, upper)``.
    """
    values = np.array(data) * 1.0  # floating copy of the input
    size = len(values)
    centre = np.mean(values)
    stderr = scipy.stats.sem(values)
    # Half-width = standard error times the critical t value with n-1 dof.
    critical = scipy.stats.t.ppf((1 + confidence) / 2., size - 1)
    half = stderr * critical
    return centre, centre - half, centre + half

def adjust_spines(ax, spines):
    """Keep only the listed spines (offset outward) and drop ticks elsewhere.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to modify in place.
    spines : container of str
        Names of the spines to keep, e.g. ``['left', 'bottom']``; every other
        spine is made invisible.
    """
    for name, spine in ax.spines.items():
        if name not in spines:
            spine.set_color('none')  # hide this spine
            continue
        spine.set_position(('outward', 10))  # push the kept spine 10 points out

    # Ticks follow the spines; only the y ('left') and x ('bottom') sides are handled.
    for axis, side in ((ax.yaxis, 'left'), (ax.xaxis, 'bottom')):
        if side in spines:
            axis.set_ticks_position(side)
        else:
            axis.set_ticks([])
        
def gaus(x,a,x0,sigma):
    """Normal probability density scaled by ``a``.

    Computes ``a / (sigma * sqrt(2*pi)) * exp(-(x - x0)**2 / (2 * sigma**2))``.
    For ``sigma > 0`` the curve integrates to ``a``, so ``a = 1`` is a
    unit-area density. That matches the scale of the KDE values it is fitted to
    with ``p0=[1, 0, 1]``.

    Parameters
    ----------
    x : array_like
        Evaluation points.
    a : float
        Area under the curve.
    x0 : float
        Mean.
    sigma : float
        Standard deviation.

    Returns
    -------
    numpy.ndarray or float
    """
    # The previous `1/sigma*np.sqrt(2*np.pi)` parsed as (1/sigma)*sqrt(2*pi),
    # i.e. it multiplied by sqrt(2*pi) instead of dividing by it.
    norm = 1.0 / (sigma * np.sqrt(2 * np.pi))
    return a * norm * np.exp(-(x - x0) ** 2 / (2 * sigma ** 2))

def gaussian(X, amp, cen, wid):
    """Unnormalised Gaussian bump ``amp * exp(-(X - cen)**2 / wid)``.

    Note that ``wid`` divides the squared distance directly; it corresponds to
    ``2 * sigma**2`` rather than to ``sigma``.
    """
    exponent = -((X - cen) ** 2) / wid
    return amp * exp(exponent)

def getPossionPDF(mu,x):
    """Poisson probability P(K = x) for rate ``mu``, with input clamping.

    ``x`` is clipped to [0, 170] and rounded to the nearest integer. 170 is the
    largest n whose factorial still converts to a float. ``mu`` is shifted by
    +0.01 before use, presumably so that a zero rate stays usable. A negative
    result, which is possible when ``mu + 0.01 < 0``, is clipped to 0.

    Parameters
    ----------
    mu : float
        Poisson rate (mean).
    x : int or float
        Count at which to evaluate; non-integers are rounded.

    Returns
    -------
    float
    """
    if x > 170:
        x = 170
    mu = mu + 0.01
    if x < 0:
        x = 0
    # int() so math.factorial accepts float-like inputs (e.g. np.float64, whose
    # round() may return a float that factorial rejects on Python >= 3.10).
    k = int(round(x))
    try:
        out = math.exp(-mu) * (mu ** k) / math.factorial(k)
    except OverflowError:
        if mu <= 0:
            raise
        # mu**k overflows a float for large rates; evaluate the same quantity in log space.
        out = math.exp(k * math.log(mu) - mu - math.lgamma(k + 1))
    if out < 0:
        out = 0
    return out

#tuning curve fitting
def vonMisesFunction(x,b,a,u):
    """Rectified cosine tuning curve.

    Evaluates ``b + a * cos(x - u)`` elementwise and sets negative values to 0.
    Despite the name, the profile is a plain cosine, not a von Mises function.

    Parameters
    ----------
    x : array_like
        Angle(s) passed to ``cos`` (radians) -- TODO confirm units at call sites.
    b, a, u : float
        Additive offset, cosine amplitude, and phase (peak location when a > 0).

    Returns
    -------
    numpy.ndarray
        Rectified response (a 0-d array for scalar input).
    """
    response = np.array(b + a * cos(x - u))
    below_zero = response < 0
    response[below_zero] = 0
    return response

def getvonMisesParas(x,y):
    """Least-squares fit of ``vonMisesFunction`` to (x, y).

    x : hand position
    y : firing rate

    Returns the best-fit parameter array ``[b, a, u]``.
    """
    start = [1, 0, 1]  # initial guesses for (b, a, u)
    params, _cov = curve_fit(vonMisesFunction, x, y, p0=start, maxfev=500000)
    return params

def getExpParas(x,y):
    """Least-squares fit of ``expFunction`` (``a * exp(-b * x) + c``) to (x, y).

    Parameters
    ----------
    x, y : array_like
        Independent and dependent data. NOTE(review): these were documented as
        "hand position" / "firing rate", which looks copied from
        getvonMisesParas -- confirm against the call sites.

    Returns
    -------
    numpy.ndarray
        Best-fit parameters ``[a, b, c]``, in expFunction's argument order.
    """
    # Initial guesses for (a, b, c): unit amplitude, zero decay rate, unit offset.
    init_vals = [1, 0, 1]
    best_vals, _covar = curve_fit(expFunction, x, y, p0=init_vals, maxfev=500000)
    return best_vals

def expFunction(x, a, b, c):
    """Exponential decay with offset: ``a * exp(-b * x) + c``."""
    decay = np.exp(-b * x)
    return a * decay + c
#%% ------------figure 2---------------- 
# Load all pre-computed Figure 2 inputs. The file handle is closed after loading.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
with open('fig2Data.pickle', 'rb') as fig2File:
    fig2Data = pkl.load(fig2File)

# ---- Panel 2A (left): mean P(common) per disparity, pooled and per dataset
pCommonMeanDisp = fig2Data['pCommonMeanDisp']
pCommonMeanHarryDisp = fig2Data['pCommonMeanHarryDisp']
pCommonMeanNicoDisp = fig2Data['pCommonMeanNicoDisp']
pCommonMeanSealDisp = fig2Data['pCommonMeanSealDisp']

# ---- Panel 2A (right): P(common) on the disparity x drift grid
pCommonMean = fig2Data['pCommonMean']

# ---- Panel 2B: HMM parameters per dataset, their pooled stack, and means
harryHMMPara = fig2Data['harryHMMPara']
sealHMMPara = fig2Data['sealHMMPara']
nicoHMMPara = fig2Data['nicoHMMPara']
# Rows stacked in Harry, Seal, Nico order.
combHMMPara = np.vstack((harryHMMPara, sealHMMPara, nicoHMMPara))
# NaN-aware column means; rows = pooled, Harry, Nico, Seal.
behaHMMCombSingMean = np.full((4, 2), np.nan)
behaHMMCombSingMean[0, :] = np.nanmean(combHMMPara, 0)
behaHMMCombSingMean[1, :] = np.nanmean(harryHMMPara, 0)
behaHMMCombSingMean[2, :] = np.nanmean(nicoHMMPara, 0)
behaHMMCombSingMean[3, :] = np.nanmean(sealHMMPara, 0)

# ---- Panel 2C: post-hoc results, VP/VPC collections, example session
posthocsGroup = fig2Data['posthocsGroup']
posthocsTime = fig2Data['posthocsTime']
dataCollVPC = fig2Data['dataCollVPC']
dataCollVP = fig2Data['dataCollVP']
dataCollVPHarry = fig2Data['dataCollVPHarry']
dataCollVPCHarry = fig2Data['dataCollVPCHarry']
dataCollVPNico = fig2Data['dataCollVPNico']
dataCollVPCNico = fig2Data['dataCollVPCNico']
dataCollVPSeal = fig2Data['dataCollVPSeal']
dataCollVPCSeal = fig2Data['dataCollVPCSeal']
examVpPost = fig2Data['examVpPost']
examVpcPost = fig2Data['examVpcPost']

# ---- Panel 2D: normalised SD data and example-session samples
dataCombVPSealNorm = fig2Data['dataCombVPSealNorm']
dataCombVPCSealNorm = fig2Data['dataCombVPCSealNorm']
harryVPVPCSDNormClean = fig2Data['harryVPVPCSDNormClean']
nicoVPVPCSDNormClean = fig2Data['nicoVPVPCSDNormClean']
sealVPVPCSDNormClean = fig2Data['sealVPVPCSDNormClean']
VPVPCSDNormClean = fig2Data['VPVPCSDNormClean']
VPVPCSDNormRemoOutlWilc = fig2Data['VPVPCSDNormRemoOutlWilc']
examVpWithin = fig2Data['examVpWithin']
examVpcWithin = fig2Data['examVpcWithin']
#%%fig2A left
# Fig. 2A (left): mean P(common) vs disparity. The pooled data is drawn solid
# black and each individual dataset dashed grey, all with SEM error bars.
disparity = [-45,-35,-20,-10,0,10,20,35,45]
s1 = 30
ap1 = 1
# (array, extra errorbar style); each array is averaged over rows per disparity column.
pCommonSeries = (
    (pCommonMeanDisp, {'color': 'k', 'alpha': 1}),
    (pCommonMeanHarryDisp, {'ls': '--', 'color': 'gray', 'alpha': .6}),
    (pCommonMeanNicoDisp, {'ls': '--', 'color': 'gray', 'alpha': .6}),
    (pCommonMeanSealDisp, {'ls': '--', 'color': 'gray', 'alpha': .6}),
)
xDisparity = np.unique(disparity)
with plt.style.context('style_paper.mplstyle'):
    plt.figure(figsize=[3.54/2,3.54/2])
    for seriesData, seriesStyle in pCommonSeries:
        # The SEM counts only non-NaN entries per column, matching the nanmean
        # used for the centre. Dividing by shape[0] understated the SEM whenever
        # NaNs were present. ddof=1 matches scipy.stats.sem.
        nValid = np.sum(~np.isnan(seriesData), axis=0)
        sem = np.nanstd(seriesData, axis=0, ddof=1) / np.sqrt(nValid)
        plt.errorbar(xDisparity, np.nanmean(seriesData, axis=0), yerr=sem,
                     elinewidth=2, linewidth=1, **seriesStyle)
    plt.xlabel('Disparity (deg)')
    plt.ylabel('$P_{com}$')
    plt.tight_layout()
    fileName = 'fig2A_meanPCommonAverCI.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig2A right
# Fig. 2A (right): one marker per cell (i, j) of pCommonMean, placed at
# x = drif[j], y = disp[i], coloured on a fixed [0, 1] scale.
disp = np.array([-45,-35,-20,-10,0,10,20,35,45])
drif = np.array([45,35,20,10,0,-10,-20,-35,-45])
# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9.
cm = plt.get_cmap('jet')
with plt.style.context('style_paper.mplstyle'):
    plt.figure(figsize=[3.54/1.6,3.54/2])
    nRows, nCols = pCommonMean.shape
    # A single scatter for the whole grid; the row-major ravel order matches
    # the former nested i/j loop of one-point scatters.
    gridX, gridY = np.meshgrid(drif[:nCols], disp[:nRows])
    sc = plt.scatter(gridX.ravel(), gridY.ravel(),
                     s=100,
                     c=np.ravel(pCommonMean),
                     vmin=.0,
                     vmax=1,
                     cmap=cm)
    cb = plt.colorbar(sc)
    cb.outline.set_visible(False)
    # NOTE(review): x holds the `drif` values but is labelled 'Disparity', and y
    # holds the `disp` values but is labelled 'Drift' -- confirm labels vs data.
    plt.xlabel('Disparity (deg)')
    plt.ylabel('Drift (deg)')
    plt.tight_layout()
    fileName = 'fig2A_pCommonMean.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%% fig2B boxplot
# Fig. 2B: columns 0 and 1 of combHMMPara (P(C1|C1), P(C1|C2)) shown as hollow
# points plus a filled boxplot, zoomed to y in [0.8, 1.01].
inpuData = copy.deepcopy(combHMMPara)
tansDF = pd.DataFrame()
widt1 = .15
barWidth = 0.2
r1 = np.arange(1)+barWidth
r2 = r1 + barWidth
s1 = 30
lw1 = .25
nPoints = len(inpuData)
# Long-format table of the plotted values (not used by the plotting below).
tansDF['trans'] = list(inpuData[:,0])+list(inpuData[:,1])
tansDF['condition'] = ['c1c1'] * nPoints + ['c1c2'] * nPoints
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    fig, ax = plt.subplots(figsize=(3.54/2,3.54/2))
    # (data column, x position, colour); points are not jittered horizontally.
    panels = ((0, r1, colors[0]), (1, r2, colors[1]))
    for col, pos, colr in panels:
        # facecolors='none' is the documented way to draw hollow markers.
        plt.scatter(np.full(nPoints, pos[0]), inpuData[:, col],
                    alpha=.5,
                    s=s1,
                    facecolors='none',
                    marker='o',
                    edgecolors=colr)
    for col, pos, colr in panels:
        bplot = plt.boxplot(inpuData[:, [col]], positions=pos,
                            showfliers=False,
                            widths=widt1,
                            medianprops={'color': 'k'},
                            patch_artist=True)
        for patch in bplot['boxes']:
            patch.set(color=colr)
            patch.set(facecolor=colr)
            patch.set(alpha=.5)
    plt.xticks([r1[0],r2[0]],['$P_{(C1|C1)}$','$P_{(C1|C2)}$'])
    plt.xlim([r1[0]-barWidth/2.5,r2[0]+barWidth/2.5])
    # Zoomed range: values below 0.8 fall outside this panel.
    plt.ylim(bottom=0.8, top=1.01)
    plt.xlabel(None)
    plt.ylabel('Transition prob.')
    adjust_spines(ax, ['left', 'bottom'])
    plt.tight_layout()
    fileName = 'fig2B_behaHMMComb1.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%% fig2B scatter mini
# Fig. 2B inset: the same two columns of combHMMPara drawn small with y capped
# at 0.801 (presumably the complement of the main panel's 0.8-1.01 range).
inpuData = copy.deepcopy(combHMMPara)
tansDF = pd.DataFrame()
widt1 = .15
barWidth = 0.2
r1 = np.arange(1)+barWidth
r2 = r1 + barWidth
s1 = 30/2
lw1 = .25
nPoints = len(inpuData)
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    fig, ax = plt.subplots(figsize=(3.54/2/2,3.54/2/2))
    # (data column, x position, colour); points are not jittered horizontally.
    panels = ((0, r1, colors[0]), (1, r2, colors[1]))
    for col, pos, colr in panels:
        # facecolors='none' is the documented way to draw hollow markers.
        plt.scatter(np.full(nPoints, pos[0]), inpuData[:, col],
                    alpha=.5,
                    s=s1,
                    facecolors='none',
                    marker='o',
                    edgecolors=colr)
    for col, pos, colr in panels:
        bplot = plt.boxplot(inpuData[:, [col]], positions=pos,
                            showfliers=False,
                            widths=widt1,
                            medianprops={'color': 'k'},
                            patch_artist=True)
        for patch in bplot['boxes']:
            patch.set(color=colr)
            patch.set(facecolor=colr)
            patch.set(alpha=.5)
    plt.xlim([r1[0]-barWidth/2.5,r2[0]+barWidth/2.5])
    plt.xticks([r1[0],r2[0]],[])
    plt.ylim(top = .801)
    plt.xlabel(None)
    plt.yticks([0,0.45,0.8],[0,0.45,0.8])
    adjust_spines(ax, ['left', 'bottom'])
    plt.tight_layout()
    fileName = 'fig2B_behaHMMComb_scatter.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig2C
# Fig. 2C: columns 0 ('Early') and 2 ('Late') of the VP (colour 1) and VPC
# (black) collections, as hollow points plus filled boxplots.
s1 = 25
alpha1 = .5
r1 = 0
r2 = 1
r3 = 2.2
r4 = 3.2
lw1 = .25
widt1 = 0.8
inpuData1 = copy.deepcopy(dataCollVP)
inpuData2 = copy.deepcopy(dataCollVPC)
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    fig, ax = plt.subplots(figsize=(3.54/2,3.54/2))
    # (data, column, x position, colour, marker alpha)
    groups = (
        (inpuData1, 0, r1, colors[1], .5),
        (inpuData2, 0, r2, 'k', .2),
        (inpuData1, 2, r3, colors[1], .5),
        (inpuData2, 2, r4, 'k', .2),
    )
    for grpData, grpCol, grpPos, grpColor, grpAlpha in groups:
        # x values are sized from *this* dataset. The old code reused a list
        # sized for inpuData1, which raised IndexError if inpuData2 was longer.
        plt.scatter(np.full(len(grpData), grpPos), grpData[:, grpCol],
                    alpha=grpAlpha,
                    s=s1,
                    facecolors='none',  # hollow markers
                    marker='o',
                    edgecolors=grpColor)
    for grpData, grpCol, grpPos, grpColor, grpAlpha in groups:
        bplot = plt.boxplot(grpData[:, [grpCol]], positions=[grpPos],
                            showfliers=False,
                            widths=widt1,
                            medianprops={'color': 'k'},
                            patch_artist=True)
        for patch in bplot['boxes']:
            patch.set(color=grpColor)
            patch.set(facecolor=grpColor)
            patch.set(alpha=.5)
    plt.xticks([((r1+r2)/2),(r3+r4)/2],['Early','Late'])
    plt.ylabel('Standard deviation (deg)')
    adjust_spines(ax, ['left', 'bottom'])
    fig.tight_layout()
    fileName = 'fig2C_behaAftereffectScatter.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()

with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    plt.figure(figsize=(3.54/1.8,3.54/2))
    # Area-normalised 5-bin histograms of the example session: VPC (black) first, then VP.
    plt.hist(examVpcPost, bins=5, density=True, alpha=.5, color='k')
    plt.hist(examVpPost, bins=5, density=True, alpha=.5, color=colors[1])
    # NOTE(review): with density=True the y values are a probability density
    # (per degree), not a fraction of trials -- consider relabelling.
    plt.ylabel('Trial counts (Fraction)')
    plt.xlabel('Degree')
    # Gaussian summary of each sample: fit `gaus` to the sample's KDE
    # evaluated at the sample points themselves.
    driftVPDens = stats.gaussian_kde(examVpPost)
    poptVP, pcovVP = curve_fit(gaus, examVpPost, driftVPDens(examVpPost), p0=[1, 0, 1])
    xvp = np.linspace(-10, 10, 1000)
    yVP = gaus(xvp, *poptVP)
    driftVPCDens = stats.gaussian_kde(examVpcPost)
    poptVPC, pcovVPC = curve_fit(gaus, examVpcPost, driftVPCDens(examVpcPost), p0=[1, 0, 1])
    xvpc = np.linspace(-10, 10, 1000)
    yVPC = gaus(xvpc, *poptVPC)

    plt.plot(xvp, yVP, color=colors[1], linewidth=2, label='After VP')
    plt.plot(xvpc, yVPC, color='k', alpha=1, linewidth=2, label='After VPC')
    plt.legend(bbox_to_anchor=[.54,.65])
    plt.xlim([-10,10])
    plt.tight_layout()
    fileName = 'fig2C_behaAftereffectExamSess.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()

#%%fig2D
# Fig. 2D: columns 0 (VP, colour 0) and 1 (VPC 0 deg, black) of
# VPVPCSDNormClean as hollow points plus filled boxplots.
inpuData = copy.deepcopy(VPVPCSDNormClean)
tansDF = pd.DataFrame()
widt1 = .15
barWidth = 0.2
r1 = np.arange(1)+barWidth
r2 = r1 + barWidth
s1 = 30
lw1 = .25
nPoints = len(inpuData)
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    fig, ax = plt.subplots(figsize=(3.54/2,3.54/2))
    # (data column, x position, colour); points are not jittered horizontally.
    panels = ((0, r1, colors[0]), (1, r2, 'k'))
    for col, pos, colr in panels:
        # facecolors='none' is the documented way to draw hollow markers.
        plt.scatter(np.full(nPoints, pos[0]), inpuData[:, col],
                    alpha=.5,
                    s=s1,
                    facecolors='none',
                    marker='o',
                    edgecolors=colr)
    for col, pos, colr in panels:
        bplot = plt.boxplot(inpuData[:, [col]], positions=pos,
                            showfliers=False,
                            widths=widt1,
                            medianprops={'color': 'k'},
                            patch_artist=True)
        for patch in bplot['boxes']:
            patch.set(color=colr)
            patch.set(facecolor=colr)
            patch.set(alpha=.5)
    plt.xticks([r1[0],r2[0]],['VP','VPC (0°)'])
    plt.ylabel('Standard deviation (deg)')
    plt.xlim([r1[0]-barWidth/2.5,r2[0]+barWidth/2.5])
    plt.xlabel(None)
    adjust_spines(ax, ['left', 'bottom'])
    plt.tight_layout()
    fileName = 'fig2D_behaVPVPCDrifSDMeanCombRaw.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()

with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    plt.figure(figsize=(3.54/1.8,3.54/2))
    # Area-normalised 5-bin histograms of the example session: VPC (black) first, then VP.
    plt.hist(examVpcWithin, bins=5, density=True, alpha=.5, color='k')
    plt.hist(examVpWithin, bins=5, density=True, alpha=.5, color=colors[0])
    # NOTE(review): with density=True the y values are a probability density
    # (per degree), not a fraction of trials -- consider relabelling.
    plt.ylabel('Trial counts (Fraction)')
    plt.xlabel('Degree')
    # Gaussian summary of each sample: fit `gaus` to the sample's KDE
    # evaluated at the sample points themselves.
    driftVPDens = stats.gaussian_kde(examVpWithin)
    poptVP, pcovVP = curve_fit(gaus, examVpWithin, driftVPDens(examVpWithin), p0=[1, 0, 1])
    xvp = np.linspace(-10, 10, 1000)
    yVP = gaus(xvp, *poptVP)
    driftVPCDens = stats.gaussian_kde(examVpcWithin)
    poptVPC, pcovVPC = curve_fit(gaus, examVpcWithin, driftVPCDens(examVpcWithin), p0=[1, 0, 1])
    xvpc = np.linspace(-10, 10, 1000)
    yVPC = gaus(xvpc, *poptVPC)

    plt.plot(xvp, yVP, color=colors[0], linewidth=2, label='VP')
    plt.plot(xvpc, yVPC, color='k', alpha=1, linewidth=2, label='VPC (0°)')
    plt.legend(bbox_to_anchor=[.54,.8])
    plt.xlim([-5,5])
    plt.tight_layout()
    fileName = 'fig2D_example20121214SealVPVPC.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
